import os
from functools import lru_cache
from math import exp

import torch
import torch.nn as nn
from flask import Flask, request
from PIL import Image
from torchvision import models, transforms

classes = 6  # number of output units of the fine-tuned classifier head

app = Flask(__name__)

# Inference-time preprocessing: resize the short side to 256, take the central
# 224x224 crop, convert to a tensor, then shift each RGB channel from [0, 1]
# to [-1, 1] using mean 0.5 / std 0.5.
infer_transformation = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

# Keys of the JSON object returned by the /predict endpoint.
LABEL_OUTPUT_KEY, MODEL_OUTPUT_KEY = 'predicted_label', 'scores'


def decode_image(file_content):
    """Open an image and return it converted to 3-channel RGB.

    Args:
        file_content: Anything ``PIL.Image.open`` accepts, e.g. a filesystem
            path or a readable binary file-like object such as an upload.

    Returns:
        A ``PIL.Image.Image`` in ``'RGB'`` mode.
    """
    return Image.open(file_content).convert('RGB')


def read_label_list(path):
    """Read a label file and return its non-blank lines, stripped.

    Args:
        path: Path to a UTF-8 text file with one label per line. A leading
            byte-order mark is tolerated.

    Returns:
        list[str]: Labels in file order. Blank or whitespace-only lines are
        skipped and surrounding whitespace is removed.
    """
    # 'utf-8-sig' decodes plain UTF-8 identically but also drops a leading BOM
    # (written by e.g. Windows Notepad); with 'utf8' the U+FEFF would survive
    # strip() and corrupt the first label.
    with open(path, 'r', encoding='utf-8-sig') as f:
        # Text mode already normalises '\r\n' / '\r' to '\n', so iterating
        # lines works on every OS. Splitting on os.linesep ('\r\n' on Windows)
        # would never split there.
        return [line.strip() for line in f if line.strip()]


def resnetrun(model_path):
    """Build a ResNet-50 with a ``classes``-way head and load trained weights.

    Args:
        model_path: Path to a saved state dict whose keys and shapes match a
            torchvision ResNet-50 whose ``fc`` layer outputs ``classes``
            logits. ``load_state_dict`` is called with its default strict
            matching.

    Returns:
        The model on CPU, switched to evaluation mode.
    """
    net = models.resnet50(pretrained=False)
    # Replace the stock classification layer with one sized for this
    # project's label set, keeping the same input feature width.
    net.fc = nn.Linear(net.fc.in_features, classes)

    # map_location='cpu' lets weights saved on a GPU load on a CPU-only host.
    state_dict = torch.load(model_path, map_location='cpu')
    net.load_state_dict(state_dict)

    net.eval()
    return net


# Class names in the order of the model's output units; must have exactly
# ``classes`` entries.
LABEL_LIST = ("cat", "crocodile", "dog", "dolphin", "snake", "sturgeon")

# Trained weights. Override with the MODEL_PATH environment variable so the
# service is not tied to one developer's desktop.
MODEL_PATH = os.environ.get('MODEL_PATH', r'C:\Users\JoyboyPC\Desktop\model.pth')


@lru_cache(maxsize=None)
def _get_model():
    """Load the trained network on first use and reuse it for later requests."""
    return resnetrun(MODEL_PATH)


def _softmax(logits):
    """Return softmax probabilities for a list of floats.

    The maximum is subtracted first so ``exp`` cannot overflow.
    """
    peak = max(logits)
    exps = [exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


@app.route('/predict', methods=['POST'])
def predict():
    """Classify the image uploaded in the multipart form field ``file``.

    Returns:
        On success, a dict with ``LABEL_OUTPUT_KEY`` mapped to the most likely
        class name and ``MODEL_OUTPUT_KEY`` mapped to a ``{label: probability}``
        dict (probabilities rounded to 3 decimals).
        On bad input, an ``({'error': message}, 400)`` tuple: either the
        ``file`` field is missing or its content cannot be decoded as an
        image.

    Raises:
        RuntimeError: If the model's output size does not match ``LABEL_LIST``.
    """
    # The route only accepts POST, so no request.method check is needed.
    upload = request.files.get('file')
    if upload is None:
        return {'error': "missing form field 'file'"}, 400
    try:
        image = decode_image(upload)
    except OSError:
        # PIL raises UnidentifiedImageError (an OSError subclass) for
        # non-image or truncated data; report it as a client error, not a 500.
        return {'error': 'uploaded file is not a readable image'}, 400

    batch = infer_transformation(image).unsqueeze(0).float()  # add batch dim
    model = _get_model()
    with torch.no_grad():  # inference only: skip building an autograd graph
        logits = model(batch)[0].tolist()
    if len(logits) != len(LABEL_LIST):
        raise RuntimeError(
            f'model produced {len(logits)} outputs but {len(LABEL_LIST)} '
            'labels are defined'
        )
    app.logger.debug('logits: %s', logits)

    probs = _softmax(logits)
    # Pick the winner from unrounded probabilities. Rounding first can turn
    # near-ties into exact ties and select the wrong class.
    best = max(range(len(probs)), key=probs.__getitem__)
    scores = {label: round(p, 3) for label, p in zip(LABEL_LIST, probs)}
    app.logger.debug('scores: %s', scores)

    return {
        LABEL_OUTPUT_KEY: LABEL_LIST[best],
        MODEL_OUTPUT_KEY: scores,
    }


if __name__ == '__main__':
    # Run Flask's built-in development server with its default settings when
    # executed as a script; for production, serve `app` via a WSGI server.
    app.run()